# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Records previous preprocessing operations and allows them to be repeated.

Used with object_detection.core.preprocessor. Passing a PreprocessorCache
into individual data augmentation functions or the general preprocess() function
will store all randomly generated variables in the PreprocessorCache. When
a preprocessor function is called multiple times with the same
PreprocessorCache object, that function will perform the same augmentation
on all calls.
"""

from collections import defaultdict


class PreprocessorCache(object):
    """Dictionary wrapper storing random variables generated during preprocessing.
    """

    # Constant keys representing different preprocessing functions
    ROTATION90 = 'rotation90'
    HORIZONTAL_FLIP = 'horizontal_flip'
    VERTICAL_FLIP = 'vertical_flip'
    PIXEL_VALUE_SCALE = 'pixel_value_scale'
    IMAGE_SCALE = 'image_scale'
    RGB_TO_GRAY = 'rgb_to_gray'
    ADJUST_BRIGHTNESS = 'adjust_brightness'
    ADJUST_CONTRAST = 'adjust_contrast'
    ADJUST_HUE = 'adjust_hue'
    ADJUST_SATURATION = 'adjust_saturation'
    DISTORT_COLOR = 'distort_color'
    STRICT_CROP_IMAGE = 'strict_crop_image'
    CROP_IMAGE = 'crop_image'
    PAD_IMAGE = 'pad_image'
    CROP_TO_ASPECT_RATIO = 'crop_to_aspect_ratio'
    RESIZE_METHOD = 'resize_method'
    PAD_TO_ASPECT_RATIO = 'pad_to_aspect_ratio'
    BLACK_PATCHES = 'black_patches'
    ADD_BLACK_PATCH = 'add_black_patch'
    SELECTOR = 'selector'
    SELECTOR_TUPLES = 'selector_tuples'
    SSD_CROP_SELECTOR_ID = 'ssd_crop_selector_id'
    SSD_CROP_PAD_SELECTOR_ID = 'ssd_crop_pad_selector_id'

    # 23 permitted function ids
    _VALID_FNS = [ROTATION90, HORIZONTAL_FLIP, VERTICAL_FLIP, PIXEL_VALUE_SCALE,
                  IMAGE_SCALE, RGB_TO_GRAY, ADJUST_BRIGHTNESS, ADJUST_CONTRAST,
                  ADJUST_HUE, ADJUST_SATURATION, DISTORT_COLOR, STRICT_CROP_IMAGE,
                  CROP_IMAGE, PAD_IMAGE, CROP_TO_ASPECT_RATIO, RESIZE_METHOD,
                  PAD_TO_ASPECT_RATIO, BLACK_PATCHES, ADD_BLACK_PATCH, SELECTOR,
                  SELECTOR_TUPLES, SSD_CROP_SELECTOR_ID, SSD_CROP_PAD_SELECTOR_ID]

    def __init__(self):
        self._history = defaultdict(dict)

    def clear(self):
        """Resets cache."""
        self._history = {}

    def get(self, function_id, key):
        """Gets stored value given a function id and key.

        Args:
          function_id: identifier for the preprocessing function used.
          key: identifier for the variable stored.
        Returns:
          value: the corresponding value, expected to be a tensor or
                 nested structure of tensors.
        Raises:
          ValueError: if function_id is not one of the 23 valid function ids.
        """
        if function_id not in self._VALID_FNS:
            raise ValueError('Function id not recognized: %s.' % str(function_id))
        return self._history[function_id].get(key)

    def update(self, function_id, key, value):
        """Adds a value to the dictionary.

        Args:
          function_id: identifier for the preprocessing function used.
          key: identifier for the variable stored.
          value: the value to store, expected to be a tensor or nested structure
                 of tensors.
        Raises:
          ValueError: if function_id is not one of the 23 valid function ids.
        """
        if function_id not in self._VALID_FNS:
            raise ValueError('Function id not recognized: %s.' % str(function_id))
        self._history[function_id][key] = value
